{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Epochs are set here so we can test these notebooks in our CI workflow.\n",
    "epochs = 25"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Federated CIFAR 10 Example</h1>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example demonstrates how a simple image classifier written in PyTorch can be trained using federated learning with PySyft. We distribute the image data between two workers, Bob and Alice; the model is sent to each worker and trained on that worker's data. After training, the model is sent back to its owner and used to make predictions.\n",
    "\n",
    "\n",
    "Hrishikesh Kamath - GitHub: @<a href=\"http://github.com/kamathhrishi\">kamathhrishi</a>\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Import required libraries\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy  # <-- NEW: import the PySyft library"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Setting Up the Learning Task</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size = 64\n",
    "        self.test_batch_size = 1000\n",
    "        self.epochs = epochs\n",
    "        self.lr = 0.01\n",
    "        self.momentum = 0.5\n",
    "        self.no_cuda = True\n",
    "        self.seed = 1\n",
    "        self.log_interval = 200\n",
    "        self.save_model = False\n",
    "\n",
    "args = Arguments()\n",
    "\n",
    "use_cuda = not args.no_cuda and torch.cuda.is_available()\n",
    "\n",
    "torch.manual_seed(args.seed)\n",
    "\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3>Initialize Workers</h3>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook = sy.TorchHook(torch)  # <-- NEW: hook PyTorch, i.e. add extra functionality to support Federated Learning\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")  # <-- NEW: define remote worker bob\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")  # <-- NEW: and alice"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Load and Distribute Dataset</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data():\n",
    "    \n",
    "    '''Load the CIFAR-10 dataset from torchvision and distribute the training set to the workers using PySyft's FederatedDataLoader.'''\n",
    "    \n",
    "\n",
    "    federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader \n",
    "    datasets.CIFAR10('../data', train=True, download=True,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "                   ]))\n",
    "    .federate((bob, alice)), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset\n",
    "    batch_size=args.batch_size, shuffle=True, **kwargs)\n",
    "\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.CIFAR10('../data', train=False, transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "                   ])),\n",
    "    batch_size=args.test_batch_size, shuffle=True, **kwargs)\n",
    "    \n",
    "    return federated_train_loader,test_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Define Neural Network Model</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Train Function</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def train(args, model, device, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    # Iterate over the loader passed in as an argument rather than relying on a global variable\n",
    "    for batch_idx, (data, target) in enumerate(train_loader): # <-- now it is a distributed dataset\n",
    "        model.send(data.location) # <-- NEW: send the model to the right location\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.get() # <-- NEW: get the model back\n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            loss = loss.get() # <-- NEW: get the loss back\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * args.batch_size, len(train_loader) * args.batch_size, #batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Test Function</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(args, model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n",
    "            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability \n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1>Train Neural Network</h1>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#<--Load federated training data and test data\n",
    "federated_train_loader,test_loader=load_data()\n",
    "\n",
    "#<--Create Neural Network model instance\n",
    "model = Net().to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=args.lr) #<--TODO momentum is not supported at the moment\n",
    "\n",
    "#<--Train Neural network and validate with test set after completion of training every epoch\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(args, model, device, federated_train_loader, optimizer, epoch)\n",
    "    test(args, model, device, test_loader)\n",
    "\n",
    "if (args.save_model):\n",
    "    torch.save(model.state_dict(), \"cifar10_cnn.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Well Done!\n",
    "\n",
    "And voilà! We are now training a real-world learning model using Federated Learning! As you observed, we modified 10 lines of code to upgrade the official PyTorch example on CIFAR10 to a real Federated Learning setting!\n",
    "\n",
    "## Shortcomings of this Example\n",
    "\n",
    "The optimizer does not currently support the momentum argument, so training a CNN on the CIFAR10 dataset takes more epochs than it otherwise would.\n",
    "\n",
    "Of course, there are dozens of improvements we could think of. We would like the computation to operate in parallel on the workers, to update the central model only every `n` batches, to reduce the number of messages we use to communicate between workers, etc.\n",
    "\n",
    "On the security side, this example still has some major shortcomings. Most notably, when we call `model.get()` and receive the updated model from Bob or Alice, we can actually learn a lot about Bob and Alice's training data by looking at their gradients. We could **average the gradient across multiple individuals before uploading it to the central server**, like we did in Part 4 of the tutorials section."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy-preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Check out our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to give a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Check out the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! \n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked `Good First Issue`.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
